[Nonlinear.ReverseAd.Coloring] fix acyclic coloring algorithm #2898
Conversation
|
It was a bit hard to reason about the two subsequent diffs, so here's the combined one of what the original PR should have been:
diff --git a/src/Nonlinear/ReverseAD/Coloring/Coloring.jl b/src/Nonlinear/ReverseAD/Coloring/Coloring.jl
index f847ce4e..5f00563a 100644
--- a/src/Nonlinear/ReverseAD/Coloring/Coloring.jl
+++ b/src/Nonlinear/ReverseAD/Coloring/Coloring.jl
@@ -6,8 +6,7 @@
module Coloring
-import DataStructures
-
+include("IntDisjointSet.jl")
include("topological_sort.jl")
"""
@@ -154,7 +153,7 @@ function _prevent_cycle(
forbiddenColors,
color,
)
- er = DataStructures.find_root!(S, e_idx2)
+ er = _find_root!(S, e_idx2)
@inbounds first = firstVisitToTree[er]
p = first.source # but this depends on the order?
q = first.target
@@ -172,29 +171,11 @@ function _grow_star(v, w, e_idx, firstNeighbor, color, S)
@inbounds if p != v
firstNeighbor[color[w]] = _Edge(e_idx, v, w)
else
- union!(S, e_idx, e.index)
- end
- return
-end
-
-function _merge_trees(eg, eg1, S)
- e1 = DataStructures.find_root!(S, eg)
- e2 = DataStructures.find_root!(S, eg1)
- if e1 != e2
- union!(S, eg, eg1)
+ _union!(S, e_idx, e.index)
end
return
end
-# Work-around a deprecation in [email protected]
-function _IntDisjointSet(n)
- @static if isdefined(DataStructures, :IntDisjointSet)
- return DataStructures.IntDisjointSet(n)
- else
- return DataStructures.IntDisjointSets(n) # COV_EXCL_LINE
- end
-end
-
"""
acyclic_coloring(g::UndirectedGraph)
@@ -214,7 +195,6 @@ function acyclic_coloring(g::UndirectedGraph)
firstNeighbor = _Edge[]
firstVisitToTree = fill(_Edge(0, 0, 0), _num_edges(g))
color = fill(0, _num_vertices(g))
- # disjoint set forest of edges in the graph
S = _IntDisjointSet(_num_edges(g))
@inbounds for v in 1:_num_vertices(g)
n_neighbor = _num_neighbors(v, g)
@@ -293,7 +273,7 @@ function acyclic_coloring(g::UndirectedGraph)
continue
end
if color[x] == color[v]
- _merge_trees(e_idx, e2_idx, S)
+ _union!(S, e_idx, e2_idx)
end
end
end
diff --git a/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl b/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl
new file mode 100644
index 00000000..4fb6ea26
--- /dev/null
+++ b/src/Nonlinear/ReverseAD/Coloring/IntDisjointSet.jl
@@ -0,0 +1,56 @@
+# Copyright (c) 2017: Miles Lubin and contributors
+# Copyright (c) 2017: Google Inc.
+# Copyright (c) 2024: Guillaume Dalle and Alexis Montoison
+#
+# Use of this source code is governed by an MIT-style license that can be found
+# in the LICENSE.md file or at https://opensource.org/licenses/MIT.
+
+# The code in this file was taken from
+# https://github.com/gdalle/SparseMatrixColorings.jl/blob/main/src/Forest.jl
+#
+# It was copied at the suggestion of Alexis in his JuMP-dev 2025 talk.
+#
+# @odow made minor changes to match MOI coding styles.
+#
+# x-ref https://github.com/gdalle/SparseMatrixColorings.jl/pull/190
+
+mutable struct _IntDisjointSet
+ # current number of distinct trees in the S
+ number_of_trees::Int
+ # vector storing the index of a parent in the tree for each edge, used in
+ # union-find operations
+ parents::Vector{Int}
+ # vector approximating the depth of each tree to optimize path compression
+ ranks::Vector{Int}
+
+ _IntDisjointSet(n::Integer) = new(n, collect(1:n), zeros(Int, n))
+end
+
+function _find_root!(S::_IntDisjointSet, x::Integer)
+ p = S.parents[x]
+ if S.parents[p] != p
+ S.parents[x] = p = _find_root!(S, p)
+ end
+ return p
+end
+
+function _root_union!(S::_IntDisjointSet, x::Int, y::Int)
+ rank1, rank2 = S.ranks[x], S.ranks[y]
+ if rank1 < rank2
+ x, y = y, x
+ elseif rank1 == rank2
+ S.ranks[x] += 1
+ end
+ S.parents[y] = x
+ S.number_of_trees -= 1
+ return
+end
+
+function _union!(S, x::Int, y::Int)
+ root_x = _find_root!(S, x)
+ root_y = _find_root!(S, y)
+ if root_x != root_y
+ _root_union!(S, root_x, root_y)
+ end
+ return
+end |
|
As described at #2882 (comment), for #2882 I ran all the PureJuMP models in OptimizationProblems and verified the hessians before and after the change. I think we should do that here. |
|
Yip. I'll also re-run the solver tests. |
@odow It is because the edge and the star don't have a "shared" edge, so we can avoid merging them. |
|
I see the issue, you checked that the "roots" are different but you merged the trees with the edge indices instead of the root indices. |
|
Now there are many large differences between SCT and JuMP: ┌ Warning: Inconsistencies were detected
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:74
┌ Warning: Inconsistency for Jacobian of hs117: SCT (75 nz) ⊃ JuMP (62 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:76
┌ Warning: Inconsistency for Jacobian of lincon: SCT (19 nz) ⊃ JuMP (17 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:76
┌ Warning: Inconsistency for Hessian of argauss: SCT (8 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of biggs5: SCT (9 nz) ⊂ JuMP (12 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of biggs6: SCT (9 nz) ⊂ JuMP (12 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of britgas: SCT (1087 nz) ⊂ JuMP (1111 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of chain: SCT (75 nz) ⊂ JuMP (100 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of channel: SCT (696 nz) ⊃ JuMP (672 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of dixmaane: SCT (493 nz) ⊃ JuMP (297 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of dixmaani: SCT (493 nz) ⊃ JuMP (297 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of dixmaanm: SCT (493 nz) ⊃ JuMP (297 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs114: SCT (19 nz) ⊂ JuMP (21 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs119: SCT (256 nz) ⊃ JuMP (76 nz)
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs250: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs251: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs36: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs37: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs40: SCT (15 nz) ⊂ JuMP (16 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs41: SCT (6 nz) ⊂ JuMP (9 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs45: SCT (20 nz) ⊂ JuMP (25 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs56: SCT (10 nz) ⊂ JuMP (13 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs68: SCT (9 nz) ⊂ JuMP (10 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs69: SCT (9 nz) ⊂ JuMP (10 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs87: SCT (9 nz) ⊂ JuMP (11 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of hs93: SCT (34 nz) ⊂ JuMP (36 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of polygon1: SCT (550 nz) ⊂ JuMP (600 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of polygon2: SCT (350 nz) ⊂ JuMP (400 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79
┌ Warning: Inconsistency for Hessian of robotarm: SCT (252 nz) ⊂ JuMP (276 nz) [diagonal difference only]
└ @ Main ~/git/adrhill/SparseConnectivityTracer.jl/benchmark/main.jl:79 |
Co-authored-by: Alexis Montoison <[email protected]>
|
Okay. I ran the following. Running it also triggered the assert, so if I had re-run the Optimisation tests after #2885 we would have found this sooner.
using Revise
using JuMP, OptimizationProblems, SparseArrays
function compute_random_hessian(name::String)
println(name)
try
model = getfield(OptimizationProblems.PureJuMP, Symbol(name))()
return _compute_random_hessian(model)
catch err
if err isa MOI.UnsupportedNonlinearOperator
return nothing
end
rethrow(err)
end
end
function _compute_random_hessian(model::Model)
rows = Any[]
nlp = MOI.Nonlinear.Model()
for (F, S) in list_of_constraint_types(model)
for ci in all_constraints(model, F, S)
push!(rows, ci)
object = constraint_object(ci)
MOI.Nonlinear.add_constraint(nlp, object.func, object.set)
end
end
MOI.Nonlinear.set_objective(nlp, objective_function(model))
x = all_variables(model)
backend = MOI.Nonlinear.SparseReverseMode()
evaluator = MOI.Nonlinear.Evaluator(nlp, backend, index.(x))
MOI.initialize(evaluator, [:Hess])
hessian_sparsity = MOI.hessian_lagrangian_structure(evaluator)
I = [i for (i, _) in hessian_sparsity]
J = [j for (_, j) in hessian_sparsity]
V = zeros(length(hessian_sparsity))
primal = sin.(1:length(x))
dual = cos.(1:length(rows))
MOI.eval_hessian_lagrangian(evaluator, V, primal, 1.234, dual)
return SparseArrays.sparse(I, J, V, length(x), length(x))
end
log = Dict(
name => compute_random_hessian(name)
for name in OptimizationProblems.meta[!, :name]
)
open("/tmp/log.txt", "w") do io
for name in OptimizationProblems.meta[!, :name]
H = log[name]
if H === nothing
println(io, name)
else
println(io, name, " ", nnz(H), " ", hash(H))
end
end
end
for both this PR and for [email protected]. There are a few differences:
% diff log_pr.txt log_1.46.0.txt
54c54
< britgas 759 15207358802602782180
---
> britgas 911 7370492992218032816
132c132
< hs107 17 6271903269068003906
---
> hs107 17 5313540568828064766
140c140
< hs114 15 14596070845963864172
---
> hs114 16 14723025369218343364
335c335
< polygon3 200 5952677266859018177
---
> polygon3 200 1326136011177673121
366,368c366,368
< triangle_deer 15454 1481367466358415038
< triangle_pacman 9510 4334017736271540586
< triangle_turtle 31682 1790788003688105916
---
> triangle_deer 15454 9191873022530188051
> triangle_pacman 9510 2601781815016602201
> triangle_turtle 31682 4301116183540000934There are a few that have numerical differences, but that's just to some small tolerance. Some have reduced non-zeros in the correct places. screen_recording.mov |
Closes #2897
The issue was here:
https://github.com/jump-dev/MathOptInterface.jl/pull/2885/files#diff-42ba053a9aef9ff60f40635dd168d724ad628a29f2c47038a2d78d5b12b4c680R174
I assumed @amontoison had just renamed some things, but I didn't make the corresponding change to:
https://github.com/gdalle/SparseMatrixColorings.jl/blob/9b52faccdaae41d3ce27158434cc5597d1a61a36/src/coloring.jl#L390-L392
here's the original upstream of `Base.union!`:
https://github.com/JuliaCollections/DataStructures.jl/blob/b67c498a11402f6c18e5e74c69d95e2621f75aa0/src/disjoint_set.jl#L83-L94
There's still one small difference.
In SparseMatrixColorings.jl, the code to merge two trees is
but in the DataStructures version of JuMP it was equivalent to